import ij.*;
import ij.plugin.PlugIn;
import ij.process.*;
import ij.measure.Calibration;
import ij.gui.GenericDialog;
import ij.gui.StackWindow;
import java.util.*;

/**
 * Processes an 8-bit or 16-bit stack or hyperstack with a standard ("2D"), extended ("2D extended")
 * or full 3D 3x3 hybrid median filter, writing the result to a new 8-bit image. 16-bit input is
 * converted to 8-bit in place (after the user confirms the dialog) before filtering.
 */
public class Filter_Hybrid_Median implements PlugIn {

    private ImagePlus originalImage, filteredImage, extractedZStack, filteredExtractedZStack;
    private ImageProcessor original, filtered, extracted, filteredExtracted;
    private Calibration copyCal;
    private byte[] pixels;
    private boolean[] filterChannels;                              // Indicators of which channels will be filtered.
    private final String[] filterTypes = new String[] {"2D", "2D extended", "3D"};
    private String filterType;
    private int filterCount;                                       // Iterations of the filter (validated to be >= 1).

    // Scratch arrays re-used for every pixel. They are per-instance rather than static so that two
    // concurrent runs of the plugin cannot overwrite each other's intermediate values.
    private final int[] values5 = new int[5];
    private final int[] values7 = new int[7];
    private final int[] values9 = new int[9];
    private final byte[] threeByThree = new byte[9];

    private int width, height, channels, slices, frames, images;   // Stack or hyperstack parameters.
    private String windowTitle;

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Plugin entry point: asks for the filter parameters, then filters every selected channel of every
     * time point as an independent Z-stack and shows the result in a new window.
     *
     * @param arg unused plugin argument
     */
    public void run(String arg) {
        originalImage = IJ.getImage();
        if (originalImage == null) return;

        int bitDepth = originalImage.getBitDepth();
        if (bitDepth != 8 && bitDepth != 16) {
            IJ.showMessage("This plugin requires 8-bit or 16-bit images.");
            return;
        }

        width = originalImage.getWidth();
        height = originalImage.getHeight();
        channels = originalImage.getNChannels();
        slices = originalImage.getNSlices();
        frames = originalImage.getNFrames();
        images = channels * slices * frames;

        // The corner and edge filters read one neighbour on each side; thinner images would index out of bounds.
        if (width < 2 || height < 2) {
            IJ.showMessage("This plugin requires images at least 2 pixels wide and 2 pixels high.");
            return;
        }

        if (!showDialog()) return;

        // Only modify the source image once the user has confirmed the operation.
        originalImage.killRoi();
        if (bitDepth == 16) {
            new StackConverter(originalImage).convertToGray8();      // Convert 16-bit data to 8-bit in place.
        }
        original = originalImage.getProcessor();
        int originalSlice = originalImage.getCurrentSlice();

        IJ.resetEscape();
        createFilteredImage();

        // Working images holding one extracted Z-stack and its filtered version.
        extractedZStack = IJ.createImage("extractedZStack", "8-bit black", width, height, slices);
        extracted = extractedZStack.getProcessor();
        filteredExtractedZStack = IJ.createImage("filteredExtractedZStack", "8-bit black", width, height, slices);
        filteredExtracted = filteredExtractedZStack.getProcessor();

        // Extract a Z-stack for each channel at each time point, and filter the Z-stack if requested.
        for (int timePoint = 1; timePoint <= frames; timePoint++) {
            IJ.showStatus("Filtering " + timePoint + " of " + frames + " Z-stacks. Press Esc to abort.");
            IJ.showProgress((double) timePoint / frames);

            for (int ch = 1; ch <= channels; ch++) {
                if (IJ.escapePressed()) {
                    filteredImage.flush();
                    restoreOriginalView(originalSlice);
                    IJ.showStatus("Plugin aborted.");
                    IJ.showProgress(0.0);
                    return;
                }

                extractedZStack = extractZStack(original, timePoint, ch);
                extracted = extractedZStack.getProcessor();

                if (filterChannels[ch - 1]) {
                    filteredExtractedZStack = filterZStack(extractedZStack);
                    filteredExtracted = filteredExtractedZStack.getProcessor();
                    copyToFilteredImage(filteredExtractedZStack, filteredExtracted, timePoint, ch);
                } else {
                    copyToFilteredImage(extractedZStack, extracted, timePoint, ch);
                }
            }
        }

        restoreOriginalView(originalSlice);
        filteredImage.setSlice(1);
        filteredImage.show();
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Shows the parameter dialog and stores the filter type, iteration count and per-channel choices.
     *
     * @return false if the dialog was cancelled or the iteration count is not a positive number
     */
    private boolean showDialog() {
        GenericDialog gd = new GenericDialog("Parameters");
        gd.addChoice("Hybrid median filter:", filterTypes, filterTypes[0]);
        gd.addNumericField("Filter iterations:", 1, 0);
        for (int i = 1; i <= channels; i++) {
            gd.addCheckbox("Filter channel " + i, true);
        }
        gd.showDialog();
        if (gd.wasCanceled()) return false;

        filterType = gd.getNextChoice();
        filterCount = (int) gd.getNextNumber();                    // NaN (invalid entry) casts to 0.
        filterChannels = new boolean[channels];
        for (int i = 0; i < channels; i++) {
            filterChannels[i] = gd.getNextBoolean();
        }

        // With zero iterations filterZStack would return stale data instead of the input.
        if (filterCount < 1) {
            IJ.showMessage("Filter iterations must be at least 1.");
            return false;
        }
        return true;
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Creates the blank 8-bit output image with the same dimensions and calibration as the source and,
     * for composite sources, the same display mode and channel LUTs. Sets filteredImage and filtered.
     */
    private void createFilteredImage() {
        windowTitle = originalImage.getTitle();
        if (windowTitle.endsWith(".tif") || windowTitle.endsWith(".TIF")) {
            windowTitle = windowTitle.substring(0, windowTitle.length() - 4);
        }

        copyCal = originalImage.getCalibration().copy();

        filteredImage = IJ.createImage(windowTitle + " Filtered", "8-bit black", width, height, images);
        filteredImage.setDimensions(channels, slices, frames);
        if (channels > 1 || frames > 1) {
            filteredImage.setOpenAsHyperStack(true);
        }
        filteredImage.setCalibration(copyCal);

        // Ensure that the display parameters are preserved in the filtered image.
        if (originalImage.isComposite()) {
            CompositeImage composite = (CompositeImage) originalImage;
            if (channels > 1) {
                filteredImage = new CompositeImage(filteredImage, composite.getMode());
                for (int i = 1; i <= channels; i++) {
                    ((CompositeImage) filteredImage).setChannelLut(composite.getChannelLut(i), i);
                }
            } else {
                LUT lut = composite.getChannelLut(1);
                filteredImage.getProcessor().setColorModel(lut);
                filteredImage.setDisplayRange(lut.min, lut.max);
            }
        }

        filtered = filteredImage.getProcessor();
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Moves the source image back to the given slice and refreshes its existing window, if it has one.
     * The existing StackWindow is updated rather than constructing a new window.
     */
    private void restoreOriginalView(int slice) {
        originalImage.setSlice(slice);
        originalImage.updateAndRepaintWindow();
        if (originalImage.getWindow() instanceof StackWindow) {
            ((StackWindow) originalImage.getWindow()).updateSliceSelector();
        }
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Copies every slice of a Z-stack into the slices of filteredImage that belong to channel ch at the
     * given time point. The hyperstack index follows the channel-fastest ordering used throughout this class.
     */
    private void copyToFilteredImage(ImagePlus zStack, ImageProcessor zProcessor, int timePoint, int ch) {
        int firstSlice = channels * slices * (timePoint - 1) + ch;
        for (int i = 1; i <= slices; i++) {
            zStack.setSlice(i);
            pixels = (byte[]) zProcessor.getPixelsCopy();
            filteredImage.setSlice(firstSlice + (i - 1) * channels);
            filtered.setPixels(pixels);
        }
    }

    //========================================================================================================================

    /**
     * Copies the Z-stack of channel ch at the given time point from the source image into extractedZStack.
     *
     * @return extractedZStack (the same working image, refilled)
     */
    private ImagePlus extractZStack(ImageProcessor original, int timePoint, int ch) {
        int firstSlice = channels * slices * (timePoint - 1) + ch;

        for (int i = 1; i <= slices; i++) {
            originalImage.setSlice(firstSlice + (i - 1) * channels);
            pixels = (byte[]) original.getPixelsCopy();
            extractedZStack.setSlice(i);
            extracted.setPixels(pixels);
        }

        return extractedZStack;
    }

    //========================================================================================================================

    /**
     * Filters the extracted Z-stack filterCount times and returns filteredExtractedZStack.
     * Corner and edge pixels are filtered on the first iteration only and copied unchanged afterwards.
     * For "2D extended" and "3D" the first and last slices use the standard 2D filter because they lack
     * a neighbouring slice on one side. Between iterations the result is copied back into
     * extractedZStack so that each pass reads the previous pass's output.
     *
     * @param extractedZStack the working Z-stack whose processor is the field {@code extracted}
     */
    private ImagePlus filterZStack(ImagePlus extractedZStack) {
        boolean useAdjacentSlices = (filterType.equals("2D extended") || filterType.equals("3D")) && slices >= 3;
        boolean extended = filterType.equals("2D extended");
        boolean threeD = filterType.equals("3D");
        boolean planarOnly = filterType.equals("2D");

        byte[] filteredPixels = new byte[width * height];
        byte[] pixelsAbove = new byte[width * height];
        byte[] pixelsBelow = new byte[width * height];

        for (int count = 1; count <= filterCount; count++) {
            for (int i = 1; i <= slices; i++) {

                // Get the pixels for the current slice, and optionally for the adjacent slices as well.
                if (useAdjacentSlices) {
                    if (i == 1) {
                        extractedZStack.setSlice(1);
                        pixels = (byte[]) extracted.getPixelsCopy();
                        extractedZStack.setSlice(2);
                        pixelsAbove = (byte[]) extracted.getPixelsCopy();
                    } else {
                        // Slide the three-slice window up by one instead of re-reading all three slices.
                        pixelsBelow = pixels;
                        pixels = pixelsAbove;
                        if (i < slices) {
                            extractedZStack.setSlice(i + 1);
                            pixelsAbove = (byte[]) extracted.getPixelsCopy();
                        }
                    }
                } else {
                    extractedZStack.setSlice(i);
                    pixels = (byte[]) extracted.getPixelsCopy();
                }

                if (count == 1) {
                    filterBorder(pixels, filteredPixels);
                } else {
                    copyBorder(pixels, filteredPixels);
                }

                // Filter the interior pixels with a standard or enhanced hybrid median filter.
                boolean planar = planarOnly || i == 1 || i == slices;
                for (int y = 1; y < height - 1; y++) {
                    for (int x = 1; x < width - 1; x++) {
                        int number = x + y * width;
                        if (planar) {
                            filteredPixels[number] = filterHybridMedian(pixels, number, width);
                        } else if (extended) {
                            filteredPixels[number] = filterExtendedHybridMedian(pixels, number, width, pixelsBelow[number], pixelsAbove[number]);
                        } else if (threeD) {
                            filteredPixels[number] = filter3DHybridMedian(pixels, number, width, pixelsBelow, pixelsAbove);
                        }
                    }
                }

                // Store a copy: filteredPixels is re-used for the next slice.
                filteredExtractedZStack.setSlice(i);
                filteredExtracted.setPixels(filteredPixels.clone());
            }

            // For further iterations, feed this pass's output back in as the next pass's input.
            if (count < filterCount) {
                for (int i = 1; i <= slices; i++) {
                    filteredExtractedZStack.setSlice(i);
                    pixels = (byte[]) filteredExtracted.getPixelsCopy();
                    extractedZStack.setSlice(i);
                    extracted.setPixels(pixels);
                }
            }
        }

        return filteredExtractedZStack;
    }

    //------------------------------------------------------------------------------------------------------------------------

    /** Writes median-filtered values for the four corners and four edges of src into dst. */
    private void filterBorder(byte[] src, byte[] dst) {
        int bottomLeft = width * (height - 1);
        int bottomRight = (width * height) - 1;

        dst[0] = filterCorner(src, 0, width, "TOP LEFT");
        dst[width - 1] = filterCorner(src, width - 1, width, "TOP RIGHT");
        dst[bottomLeft] = filterCorner(src, bottomLeft, width, "BOTTOM LEFT");
        dst[bottomRight] = filterCorner(src, bottomRight, width, "BOTTOM RIGHT");

        for (int x = 1; x < width - 1; x++) {
            dst[x] = filterEdge(src, x, width, "TOP");
            dst[bottomLeft + x] = filterEdge(src, bottomLeft + x, width, "BOTTOM");
        }
        for (int y = 1; y < height - 1; y++) {
            int left = y * width;
            int right = left + width - 1;
            dst[left] = filterEdge(src, left, width, "LEFT");
            dst[right] = filterEdge(src, right, width, "RIGHT");
        }
    }

    /** Copies the corner and edge pixels (top row, bottom row, left and right columns) of src into dst unchanged. */
    private void copyBorder(byte[] src, byte[] dst) {
        int bottomRow = width * (height - 1);
        System.arraycopy(src, 0, dst, 0, width);
        System.arraycopy(src, bottomRow, dst, bottomRow, width);
        for (int y = 1; y < height - 1; y++) {
            int left = y * width;
            int right = left + width - 1;
            dst[left] = src[left];
            dst[right] = src[right];
        }
    }

    //========================================================================================================================

    /** Converts a signed byte to an unsigned integer in the range 0-255. */
    private int toInt(byte signed) {
        return (int) signed & 0xFF;
    }

    /** Returns the median of three values without allocating or sorting. */
    private static int medianOf3(int a, int b, int c) {
        return Math.max(Math.min(a, b), Math.min(Math.max(a, b), c));
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Returns the median of a corner pixel's 2x2 neighbourhood, counting the corner pixel twice.
     *
     * @param corner one of "TOP LEFT", "TOP RIGHT", "BOTTOM LEFT", "BOTTOM RIGHT"
     */
    private byte filterCorner(byte[] px, int number, int width, String corner) {
        if (corner.equals("TOP LEFT")) {
            values5[0] = toInt(px[number]);
            values5[1] = toInt(px[number]);
            values5[2] = toInt(px[number + 1]);
            values5[3] = toInt(px[number + width]);
            values5[4] = toInt(px[number + width + 1]);
        } else if (corner.equals("TOP RIGHT")) {
            values5[0] = toInt(px[number]);
            values5[1] = toInt(px[number - 1]);
            values5[2] = toInt(px[number]);
            values5[3] = toInt(px[number + width - 1]);
            values5[4] = toInt(px[number + width]);
        } else if (corner.equals("BOTTOM LEFT")) {
            values5[0] = toInt(px[number]);
            values5[1] = toInt(px[number - width]);
            values5[2] = toInt(px[number - width + 1]);
            values5[3] = toInt(px[number]);
            values5[4] = toInt(px[number + 1]);
        } else if (corner.equals("BOTTOM RIGHT")) {
            values5[0] = toInt(px[number]);
            values5[1] = toInt(px[number - width - 1]);
            values5[2] = toInt(px[number - width]);
            values5[3] = toInt(px[number - 1]);
            values5[4] = toInt(px[number]);
        }
        Arrays.sort(values5);                    // Median value is now #2.
        return (byte) values5[2];
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Returns the median of an edge pixel's 2x3 or 3x2 neighbourhood, counting the edge pixel twice.
     *
     * @param edge one of "TOP", "BOTTOM", "LEFT", "RIGHT"
     */
    private byte filterEdge(byte[] px, int number, int width, String edge) {
        if (edge.equals("TOP")) {
            values7[0] = toInt(px[number]);
            values7[1] = toInt(px[number - 1]);
            values7[2] = toInt(px[number]);
            values7[3] = toInt(px[number + 1]);
            values7[4] = toInt(px[number + width - 1]);
            values7[5] = toInt(px[number + width]);
            values7[6] = toInt(px[number + width + 1]);
        } else if (edge.equals("BOTTOM")) {
            values7[0] = toInt(px[number]);
            values7[1] = toInt(px[number - width - 1]);
            values7[2] = toInt(px[number - width]);
            values7[3] = toInt(px[number - width + 1]);
            values7[4] = toInt(px[number - 1]);
            values7[5] = toInt(px[number]);
            values7[6] = toInt(px[number + 1]);
        } else if (edge.equals("LEFT")) {
            values7[0] = toInt(px[number]);
            values7[1] = toInt(px[number - width]);
            values7[2] = toInt(px[number]);
            values7[3] = toInt(px[number + width]);
            values7[4] = toInt(px[number - width + 1]);
            values7[5] = toInt(px[number + 1]);
            values7[6] = toInt(px[number + width + 1]);
        } else if (edge.equals("RIGHT")) {
            values7[0] = toInt(px[number]);
            values7[1] = toInt(px[number - width - 1]);
            values7[2] = toInt(px[number - 1]);
            values7[3] = toInt(px[number + width - 1]);
            values7[4] = toInt(px[number - width]);
            values7[5] = toInt(px[number]);
            values7[6] = toInt(px[number + width]);
        }
        Arrays.sort(values7);                   // Median value is now #3.
        return (byte) values7[3];
    }

    //------------------------------------------------------------------------------------------------------------------------

    /** Returns the median of the centre pixel and its four diagonal neighbours (the "X" of a 3x3 window). */
    private int diagonalMedian(byte[] px, int number, int width) {
        values5[0] = toInt(px[number - width - 1]);
        values5[1] = toInt(px[number - width + 1]);
        values5[2] = toInt(px[number]);
        values5[3] = toInt(px[number + width - 1]);
        values5[4] = toInt(px[number + width + 1]);
        Arrays.sort(values5);                   // Median value is now #2.
        return values5[2];
    }

    /** Returns the median of the centre pixel and its four edge-adjacent neighbours (the "+" of a 3x3 window). */
    private int straightMedian(byte[] px, int number, int width) {
        values5[0] = toInt(px[number - width]);
        values5[1] = toInt(px[number - 1]);
        values5[2] = toInt(px[number]);
        values5[3] = toInt(px[number + 1]);
        values5[4] = toInt(px[number + width]);
        Arrays.sort(values5);                   // Median value is now #2.
        return values5[2];
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Returns the 3x3 standard hybrid median of an interior pixel: the median of the diagonal median,
     * the straight median and the centre pixel itself.
     */
    private byte filterHybridMedian(byte[] px, int number, int width) {
        int diagonal = diagonalMedian(px, number, width);
        int straight = straightMedian(px, number, width);
        return (byte) medianOf3(diagonal, straight, toInt(px[number]));
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Returns the 3x3 extended hybrid median of an interior pixel: the median of the diagonal median,
     * the straight median and the median of the pixel with its counterparts in the slices below and above.
     */
    private byte filterExtendedHybridMedian(byte[] px, int number, int width, byte below, byte above) {
        int diagonal = diagonalMedian(px, number, width);
        int straight = straightMedian(px, number, width);
        int vertical = medianOf3(toInt(below), toInt(px[number]), toInt(above));
        return (byte) medianOf3(diagonal, straight, vertical);
    }

    //------------------------------------------------------------------------------------------------------------------------

    /**
     * Returns the full 3D hybrid median of an interior pixel: the median of nine 3x3 hybrid medians, one
     * for each plane through the pixel's 3x3x3 neighbourhood. Planes: the XY plane; the XZ and YZ planes;
     * the two vertical planes through the XY diagonals; and four planes tilted along the X or Y axis.
     *
     * Each 3x3 section is built from three rows (slice above, current slice, slice below). Each row holds
     * the pixels at {@code centre - step}, {@code centre} and {@code centre + step}. The above and below
     * rows are offset by {@code +shift} and {@code -shift} respectively.
     */
    private byte filter3DHybridMedian(byte[] px, int number, int width, byte[] pxBelow, byte[] pxAbove) {

        // The first of the nine values is the standard in-slice hybrid median.
        values9[0] = toInt(filterHybridMedian(px, number, width));

        for (int plane = 1; plane <= 8; plane++) {
            int step;    // In-row offset between neighbouring pixels of the section.
            int shift;   // Offset of the above-slice row (negated for the below-slice row).
            switch (plane) {
                case 1:  step = 1;          shift = 0;      break;   // XZ plane.
                case 2:  step = width;      shift = 0;      break;   // YZ plane.
                case 3:  step = width + 1;  shift = 0;      break;   // Vertical plane through the (x-1,y-1)-(x+1,y+1) diagonal.
                case 4:  step = 1 - width;  shift = 0;      break;   // Vertical plane through the (x-1,y+1)-(x+1,y-1) diagonal.
                case 5:  step = -width;     shift = -1;     break;   // Tilted along X: above at x-1, below at x+1.
                case 6:  step = width;      shift = 1;      break;   // Tilted along X: above at x+1, below at x-1.
                case 7:  step = 1;          shift = -width; break;   // Tilted along Y: above at y-1, below at y+1.
                default: step = -1;         shift = width;  break;   // Tilted along Y: above at y+1, below at y-1.
            }

            int aboveCentre = number + shift;
            int belowCentre = number - shift;
            threeByThree[0] = pxAbove[aboveCentre - step];
            threeByThree[1] = pxAbove[aboveCentre];
            threeByThree[2] = pxAbove[aboveCentre + step];
            threeByThree[3] = px[number - step];
            threeByThree[4] = px[number];
            threeByThree[5] = px[number + step];
            threeByThree[6] = pxBelow[belowCentre - step];
            threeByThree[7] = pxBelow[belowCentre];
            threeByThree[8] = pxBelow[belowCentre + step];

            // Treat the section as a 3-pixel-wide image whose centre is index 4.
            values9[plane] = toInt(filterHybridMedian(threeByThree, 4, 3));
        }

        // Find the median of these nine hybrid median values.
        Arrays.sort(values9);                   // Median value is now #4.
        return (byte) values9[4];
    }

}
